import mimetypes
import os
from typing import Union
import base64
import time
import requests
import random
from tqdm import tqdm
# import transformers
from PIL import Image
from tqdm import tqdm
import sys

# from otter.modeling_otter import OtterForConditionalGeneration

import json

# PROMPT = """
# You are given a pair of very similar images. In image 2, there is a specific object that is missing or changed from image 1. Generate 3 questions about that object.

# There are a few rules to follow for each question:
# 1. The question should be answerable for image 1, that is there is a definitive answer to the question, just by looking at image 1. 
# 2. The question should not be answerable for image 2. "Not answerable" means, just by looking at image 2, the answer would be something like "I don't know" or "I don't see it".
# 3. The question should be relevant to the content of each image alone, even without seeing the other image. 

# The response should be formatted as:
# - Q1: <question>
#   A1: <answer for image 1>
# - Q2: <question>
#   A2: <answer for image 1>
# - Q3: <question>
#   A3: <answer for image 1>
# """


# Instruction text placed before "Image 1:" / "Image 2:" in the GPT-4V request
# (it is the default `prompt` of test_single). It asks for one question that is
# answerable for image 1 but not for image 2, plus an answer for each image.
PROMPT = """
You are given a pair of very similar images. In image 2, there is a specific object that is missing or changed from image 1. Generate a question that is answerable for image 1 while not answerable for image 2.

There are a few rules to follow for each question:
1. The question should be answerable for image 1, that is there is a definitive answer to the question, just by looking at image 1. 
2. The question should not be answerable for image 2. "Not answerable" means, just by looking at image 2, the answer would be something like "I don't know", "I don't see SOMETHING" or "Nothing". For example,
    - If the question is "What color is the car?", and there is no car in image 2, the answer should be "I don't see a car".
    - If the question is "What is on the man's head", and there is nothing on the man's head in image 2, the answer should be "Nothing".
    - If the question is asking about something that cannot be seen clearly in image 2, the answer should be "I don't know".
    - Try not to ask questions about the presence of an object, but rather about the properties of the object. For example, instead of asking "Is there a car in the image?", ask "What color is the car?". Instead of asking "How many people are there?", ask "What is the person wearing?".
3. The question should be relevant to the content of each image alone, even without seeing the other image. 

The response should be formatted as:
- Q: <question>
- A1: <answer for image 1>
- A2: <answer for image 2, choose your answer from "I don't know", "I don't see xxx" or "Nothing". Try not to refer to the answer for image 1>
"""

# Disable urllib3 warnings: the requests issued in get_image() pass
# verify=False, which would otherwise emit an insecure-request warning per call.
# NOTE(review): called with no argument, this presumably silences every urllib3
# warning category, not only InsecureRequestWarning — consider narrowing.
requests.packages.urllib3.disable_warnings()

# ------------------- Utility Functions -------------------


def get_content_type(file_path):
    """Guess the MIME type of ``file_path`` from its extension.

    Only the type component of ``mimetypes.guess_type`` is returned (the
    encoding is discarded); ``None`` means the extension is not recognized.
    """
    return mimetypes.guess_type(file_path)[0]


# ------------------- Image Handling Functions -------------------


def get_image(url: str) -> Union[Image.Image, list]:
    """Open an image from a local path or a remote URL.

    A string containing ``"://"`` is treated as a remote URL; anything else is a
    local path. The content type is taken from the URL's ``Content-Type``
    header (remote) or guessed from the file extension (local), and the
    resource is opened with ``PIL.Image.open`` only if that type contains
    ``"image"``. Remote requests are made with TLS verification disabled.

    Args:
        url: local file path or remote URL.

    Returns:
        The opened PIL image (a list is never returned despite the annotation).

    Raises:
        ValueError: if the content type is unknown or not an image.
    """
    is_remote = "://" in url
    if is_remote:
        # requests.head does not follow redirects by default; follow them so the
        # Content-Type describes the final resource rather than a 3xx response.
        head = requests.head(url, allow_redirects=True, verify=False)
        content_type = head.headers.get("Content-Type")
    else:
        content_type = get_content_type(url)

    # Both lookups can yield None (unrecognized extension / missing header);
    # reject that explicitly instead of crashing on `"image" in None`.
    if not content_type or "image" not in content_type.lower():
        raise ValueError(f"Invalid content type {content_type!r}. Expected image.")

    if is_remote:
        return Image.open(requests.get(url, stream=True, verify=False).raw)
    return Image.open(url)


def gptv_query_paired_image(orig_img, modified_img, prompt, temp=0., max_retries=None, max_refusals=3):
    """Query the GPT-4V endpoint with ``prompt`` followed by two images.

    The request holds one user message whose content is, in order: ``prompt``,
    ``'Image 1:'``, base64 of ``orig_img``, ``'Image 2:'``, base64 of
    ``modified_img``. The endpoint URL comes from the ``GPTV_API_BASE``
    environment variable and the API key from ``GPT4V_KEY``.

    Retry policy: request exceptions, non-400 error codes (including 429),
    non-JSON bodies and responses without ``choices`` wait 60 s and retry; HTTP
    400 retries immediately. Every non-200 response raises the sampling
    temperature by 0.1 (capped at 1.0). Replies shorter than 2 characters are
    retried immediately.

    Args:
        orig_img: path to image 1 (the original).
        modified_img: path to image 2 (the edited image).
        prompt: instruction text placed before the images.
        temp: initial sampling temperature.
        max_retries: maximum number of request attempts; None (default) means
            retry indefinitely, as before.
        max_refusals: number of "I'm sorry..." replies tolerated before giving up.

    Returns:
        The ``content`` of the first choice's message, or None if the request
        was rejected with HTTP 400 three times, the model refused
        ``max_refusals`` times, or ``max_retries`` attempts were used up.

    Raises:
        ValueError: if either image path is None.
        FileNotFoundError: if either image path does not exist.
        KeyError: if ``GPTV_API_BASE`` is not set.
    """
    if orig_img is None or modified_img is None:
        raise ValueError("orig_img and modified_img must both be provided")
    for path in (orig_img, modified_img):
        if not os.path.exists(path):
            raise FileNotFoundError(path)

    def _b64(path):
        # Context manager so each image file handle is closed right away.
        with open(path, "rb") as fh:
            return base64.b64encode(fh.read()).decode()

    # Resolve the endpoint once, outside the retry loop: a missing variable is a
    # configuration error and must surface instead of being retried forever.
    api_base = os.environ["GPTV_API_BASE"]
    headers = {
        "Content-Type": "application/json",
        "api-key": os.environ.get("GPT4V_KEY"),
        'cogsvc-openai-gptv-disable-faceblur': 'true',
    }
    data = {
        'max_tokens': 2048,
        'temperature': temp,
        'top_p': 0.5,
        'messages': [{
            "role": "user",
            "content": [
                prompt,
                'Image 1:',
                {'image': _b64(orig_img)},
                'Image 2:',
                {'image': _b64(modified_img)},
            ],
        }],
    }

    regular_time = 60  # back-off in seconds between retries
    attempts, retry_400, refusals = 0, 0, 0
    while max_retries is None or attempts < max_retries:
        attempts += 1
        try:
            response = requests.post(api_base, headers=headers, data=json.dumps(data))
        except Exception as e:  # connection-level failure: back off and retry
            print(e)
            time.sleep(regular_time)
            continue

        if response.status_code != 200:
            print(response.headers, response.content)
            print(orig_img, modified_img)
            print(f"The response status code is {response.status_code} (Not OK)")
            data['temperature'] = min(data['temperature'] + 0.1, 1.0)
            if response.status_code == 400:
                # A malformed/rejected request rarely succeeds; give up after 3.
                retry_400 += 1
                if retry_400 >= 3:
                    return None
            else:
                # 429 and any other error code: back off instead of busy-looping.
                time.sleep(regular_time)
            continue

        try:
            response_json = response.json()
        except ValueError as e:  # 200 with a body that is not JSON
            print(e)
            time.sleep(regular_time)
            continue
        choices = response_json.get('choices') if isinstance(response_json, dict) else None
        if not choices:
            time.sleep(regular_time)
            continue

        # `content` may be absent/None (e.g. filtered output); treat as empty.
        response_text = (choices[0].get("message") or {}).get("content") or ''
        # Compare against a lowercase prefix: the text has been lowercased.
        if response_text.strip().lower().startswith("i'm sorry"):
            refusals += 1
            if refusals >= max_refusals:
                return None
            time.sleep(regular_time)
            continue
        if len(response_text) < 2:
            continue
        return response_text
    return None

# ------------------- Main Function -------------------

def test_single(prompt=PROMPT, orig_img=None, mask_img=None, temp=0):
    """Run a single GPT-4V query comparing ``orig_img`` with ``mask_img``.

    ``orig_img`` is sent as image 1 and ``mask_img`` as image 2, preceded by
    ``prompt`` (the module-level PROMPT by default). The result of
    gptv_query_paired_image is returned unchanged.
    """
    return gptv_query_paired_image(orig_img, mask_img, prompt, temp=temp)


def debug_vqa_gen(mask_img_folder, annotation_file, output_folder, orig_img_folder, debug=False, temp=0.7, overwrite=False):
    """Generate VQA pairs for (original, edited) image pairs listed in a JSONL file.

    Each non-blank line of ``annotation_file`` is a JSON object. The original
    image is ``orig_img_folder/<file_id>.jpg`` and the edited image is
    ``mask_img_folder/<basename of after.image_path>``. Both images are copied
    into ``output_folder`` next to ``<edited image stem>.txt``, which holds the
    model response. Pairs that already have a non-empty response file are
    skipped unless ``overwrite`` is set.

    Args:
        mask_img_folder: folder containing the edited images.
        annotation_file: path to the JSONL annotation file.
        output_folder: destination for image copies and response files.
        orig_img_folder: folder containing the original ``.jpg`` images.
        debug: if True, only the first 100 annotations are processed.
        temp: sampling temperature forwarded to ``test_single``.
        overwrite: re-copy images and regenerate existing responses.
    """
    import shutil
    os.makedirs(output_folder, exist_ok=True)
    # Close the annotation file promptly and tolerate blank lines in the JSONL.
    with open(annotation_file) as fp:
        annotations = [json.loads(line) for line in fp if line.strip()]
    orig_image_files = [os.path.join(orig_img_folder, f"{annotation['file_id']}.jpg") for annotation in annotations]
    mask_image_files = [
        os.path.join(mask_img_folder, os.path.basename(annotation['after']['image_path']))
        for annotation in annotations
    ]
    if debug:
        mask_image_files = mask_image_files[:100]

    for img_file, orig_img in tqdm(zip(mask_image_files, orig_image_files), total=len(mask_image_files)):
        im_name = os.path.basename(img_file)
        # splitext instead of str.replace(".png", ".txt"): for a non-PNG edited
        # image the old form left the name unchanged, so the "response file"
        # was the copied image itself and the pair was never queried.
        output_file = os.path.join(output_folder, os.path.splitext(im_name)[0] + ".txt")

        orig_copy = os.path.join(output_folder, os.path.basename(orig_img))
        if overwrite or not os.path.exists(orig_copy):
            shutil.copyfile(orig_img, orig_copy)
        mask_copy = os.path.join(output_folder, im_name)
        if overwrite or not os.path.exists(mask_copy):
            shutil.copyfile(img_file, mask_copy)

        if not overwrite and os.path.exists(output_file) and os.path.getsize(output_file) > 0:
            continue
        response = test_single(orig_img=orig_img, mask_img=img_file, temp=temp)
        if response is None:
            # The API gave up on this pair; writing None would raise TypeError.
            continue
        with open(output_file, "w") as out:
            out.write(response)


def vqa_gen_1k(mask_img_folder, annotation_file, output_folder, orig_img_folder, debug=False, temp=0.7, overwrite=False):
    """Generate VQA pairs for "remove anything" edited images listed in a JSONL file.

    For each non-blank annotation line, the original image is
    ``orig_img_folder/<split>/<file_id>.jpg`` when ``file_id`` contains "COCO"
    (``<split>`` being the second ``_``-separated field, e.g. ``val2014``), and
    ``orig_img_folder/<file_id>.jpg`` otherwise. The edited image is
    ``mask_img_folder/<file_id>-<question_id>_remove_0.png``. Pairs with a
    missing image are skipped. Both images are copied into ``output_folder``
    next to ``<edited image stem>.txt``, which holds the model response.

    Args:
        mask_img_folder: folder containing the edited ``.png`` images.
        annotation_file: path to the JSONL annotation file.
        output_folder: destination for image copies and response files.
        orig_img_folder: root folder of the original ``.jpg`` images.
        debug: if True, only the first 100 annotations are processed.
        temp: sampling temperature forwarded to ``test_single``.
        overwrite: re-copy images and regenerate existing responses.
    """
    import shutil
    os.makedirs(output_folder, exist_ok=True)
    # Close the annotation file promptly and tolerate blank lines in the JSONL.
    with open(annotation_file) as fp:
        annotations = [json.loads(line) for line in fp if line.strip()]

    orig_image_files = []
    for annotation in annotations:
        file_id = annotation['file_id']
        if "COCO" in file_id:
            # e.g. COCO_val2014_000000000810 -> <orig_img_folder>/val2014/COCO_val2014_000000000810.jpg
            sub_folder = file_id.split("_")[1]
            image_path = os.path.join(orig_img_folder, sub_folder, file_id + '.jpg')
        else:
            image_path = os.path.join(orig_img_folder, file_id + '.jpg')
        orig_image_files.append(image_path)
    # e.g. data/vqav2/images/remove_anything/lama/COCO_val2014_000000000810-810001_remove_0.png
    mask_image_files = [
        os.path.join(mask_img_folder, f"{annotation['file_id']}-{annotation['question_id']}_remove_0.png")
        for annotation in annotations
    ]
    if debug:
        mask_image_files = mask_image_files[:100]

    for img_file, orig_img in tqdm(zip(mask_image_files, orig_image_files), total=len(mask_image_files)):
        # Not every annotation has both images on disk; skip incomplete pairs.
        if not (os.path.exists(orig_img) and os.path.exists(img_file)):
            continue
        im_name = os.path.basename(img_file)
        output_file = os.path.join(output_folder, os.path.splitext(im_name)[0] + ".txt")

        orig_copy = os.path.join(output_folder, os.path.basename(orig_img))
        if overwrite or not os.path.exists(orig_copy):
            shutil.copyfile(orig_img, orig_copy)
        mask_copy = os.path.join(output_folder, im_name)
        if overwrite or not os.path.exists(mask_copy):
            shutil.copyfile(img_file, mask_copy)

        if not overwrite and os.path.exists(output_file) and os.path.getsize(output_file) > 0:
            continue
        response = test_single(orig_img=orig_img, mask_img=img_file, temp=temp)
        if response is None:
            # The API gave up on this pair; leave no file so a rerun retries it.
            continue
        with open(output_file, "w") as out:
            out.write(response)


if __name__ == "__main__":
    # Command-line entry point via python-fire, e.g.
    #   python <this_script>.py vqa_gen_1k --mask_img_folder=... --annotation_file=...
    #       --output_folder=... --orig_img_folder=...
    # NOTE(review): with no component argument, fire presumably exposes this
    # module's globals (including imported names) as subcommands — confirm.
    import fire
    fire.Fire()
